package orm

import (
	"fmt"
	"net"
	"net/url"
	"strings"

	"gofar/packed/cfg"
	"gofar/packed/gerr"

	"gorm.io/driver/mysql"
	"gorm.io/driver/postgres"
	"gorm.io/driver/sqlserver"
	"gorm.io/gorm"
)

// Package-level state:
//   - openMap maps a cfg Database.Type value ("mysql", "postgres",
//     "sqlserver") to a function that builds the matching GORM dialector.
//     It is filled in by initOpenMap, which init calls.
//   - db holds the connection stored by the last successful InitDB call and
//     is what DB returns. It is nil until InitDB succeeds.
//
// NOTE(review): neither variable is guarded by a mutex; this assumes InitDB
// completes before any concurrent use of DB — confirm with callers.
var (
	openMap map[string]func() gorm.Dialector
	db      *gorm.DB
)

// init builds openMap during package initialization so that InitDB can look
// up a dialector constructor by cfg Database.Type.
func init() {
	initOpenMap()
}

// DB returns the *gorm.DB stored by the most recent successful InitDB call.
// It returns nil if InitDB has not yet succeeded.
func DB() *gorm.DB {
	current := db
	return current
}

// InitDB opens a GORM connection for the database type named by
// cfg.Get().Database.Type and, on success, stores it for DB to return.
//
// It returns:
//   - gerr.UnknownDBtype() when the type has no entry in openMap;
//   - the result of gerr.Swrap when gorm.Open fails;
//   - nil on success.
//
// On either failure the previously stored handle is left unchanged.
func InitDB() *gerr.Error {
	dbType := cfg.Get().Database.Type
	newDialector, known := openMap[dbType]
	if !known {
		return gerr.UnknownDBtype()
	}

	conn, err := gorm.Open(newDialector(), &gorm.Config{})
	if err == nil {
		db = conn
		return nil
	}
	// The message reads "database connection error: " followed by err's text.
	return gerr.Swrap(&err, "数据库连接错误："+err.Error())
}

// initOpenMap populates openMap with one dialector constructor per supported
// cfg Database.Type value ("mysql", "postgres", "sqlserver"). Each constructor
// reads cfg.Get().Database when it is called, not when the map is built, so
// it sees the configuration current at InitDB time.
func initOpenMap() {
	openMap = map[string]func() gorm.Dialector{
		"mysql": func() gorm.Dialector {
			conf := cfg.Get().Database
			// Use conf.Port like the other drivers do, but only when Host
			// does not already carry a port, so "host:port" configs keep
			// working. parseTime=true makes the driver scan DATETIME and
			// TIMESTAMP columns into time.Time, which GORM time fields need.
			dsn := conf.User + ":" + conf.Passwd +
				"@tcp(" + dsnHostPort(conf.Host, fmt.Sprint(conf.Port)) + ")/" +
				conf.Dbname + "?parseTime=true"
			return mysql.Open(dsn)
		},
		"postgres": func() gorm.Dialector {
			conf := cfg.Get().Database
			// Quote free-form values so spaces, quotes or backslashes (e.g.
			// in a password) cannot break the keyword/value DSN.
			dsn := fmt.Sprintf(
				"host=%s user=%s password=%s dbname=%s port=%d sslmode=disable TimeZone=Asia/Shanghai",
				pgQuoteDSNValue(conf.Host),
				pgQuoteDSNValue(conf.User),
				pgQuoteDSNValue(conf.Passwd),
				pgQuoteDSNValue(conf.Dbname),
				conf.Port,
			)
			return postgres.Open(dsn)
		},
		"sqlserver": func() gorm.Dialector {
			conf := cfg.Get().Database
			// Build the URL with net/url so reserved characters in the
			// credentials or database name are percent-encoded instead of
			// being misread as URL delimiters.
			u := url.URL{
				Scheme:   "sqlserver",
				User:     url.UserPassword(conf.User, conf.Passwd),
				Host:     dsnHostPort(conf.Host, fmt.Sprint(conf.Port)),
				RawQuery: url.Values{"database": {conf.Dbname}}.Encode(),
			}
			return sqlserver.Open(u.String())
		},
	}
}

// dsnHostPort joins host and port as "host:port", bracketing IPv6 literals.
// It returns host unchanged when host already includes a port, or when port
// is empty or "0".
func dsnHostPort(host, port string) string {
	if port == "" || port == "0" {
		return host
	}
	if _, _, err := net.SplitHostPort(host); err == nil {
		return host
	}
	return net.JoinHostPort(host, port)
}

// pgQuoteDSNValue single-quotes v for a PostgreSQL keyword/value DSN. It
// backslash-escapes any embedded backslashes and single quotes.
func pgQuoteDSNValue(v string) string {
	return "'" + strings.NewReplacer(`\`, `\\`, `'`, `\'`).Replace(v) + "'"
}
